We can read in the embedding file, which has the outlier protein sequences detected by manual inspection of MDS plots

embeds <- readRDS("data/tmp/embeds_with_mds.rds")
table(embeds$ManualOutlier)

 FALSE   TRUE 
647389   3367 
dim(embeds)
[1] 650756    969
# Names of all columns whose name contains "embedding" (the embedding dimensions).
embed_cols <- grep("embedding", colnames(embeds), value=TRUE)

# Keep only the rows that were not flagged as manual outliers.
clean_embeds <- embeds[ManualOutlier == FALSE]
# Preview the first five columns of the filtered table.
head(clean_embeds[,1:5])
#write_parquet(clean_embeds, "data/clean_embeds.parquet")

#clean_seq <- clean_embeds[,c("ID", "Taxonomy", "Gene", "AA_seq")]
#write_parquet(clean_seq, "data/clean_AA_seqs.parquet")
#head(clean_seq)

And now we can read in the data with phenotypic values

# Load the processed table that carries the phenotype columns used below.
df <- read_parquet("data/processed_data.parquet")
# Convert to a data.table in place so data.table syntax can be used on df.
setDT(df)

#pheno <- clean_data[,c("ID", "pheno_Topt_site_p50")]
  
#write_parquet(pheno, "data/pheno_topt_clean.parquet")

Now, in the meantime, I wish to analyze the correlation between the embeddings and my phenotypes.

# Per-gene Pearson correlation between every embedding dimension and the bio8
# phenotype. Results are stored as all_cors[[gene]], a numeric vector named by
# embed_cols, and a histogram of the correlations is drawn for each gene.
genes <- unique(clean_embeds$Gene)
# The phenotype subset is the same for every gene, so build it once.
pheno <- df[, c("ID", "pheno_wc2.1_2.5m_bio_8_p50")]
all_cors <- list()
for (gene in genes) {
  # Exact match on the gene name. grep() treats the name as a regular
  # expression, so "rpl2" would also pull in the rpl20 and rpl23 rows.
  gene_data <- clean_embeds[which(clean_embeds$Gene == gene), ]

  # Inner join: only sequences that have a phenotype record are correlated.
  gene_data <- merge(gene_data, pheno, by = "ID")

  # vapply() guarantees one numeric value per embedding column.
  cors <- vapply(embed_cols, function(col) {
    cor(gene_data[[col]], gene_data$pheno_wc2.1_2.5m_bio_8_p50, use = "complete.obs")
  }, numeric(1))
  hist(cors, main = gene)
  all_cors[[gene]] <- cors
}
# Side-by-side histograms of the embedding/bio8 correlations for psbN and rbcL.
par(mfrow=c(1,2))
hist(all_cors$psbN, main="Cor of psbN embeds with bio8")
hist(all_cors$rbcL, main="rbcL")
# MDS coordinates of the psaC embeddings: overall distribution first, then one
# MDS1 histogram per taxonomic Order.
# Exact match so only psaC rows are selected (grep() would match by regex).
gene_data <- clean_embeds[which(clean_embeds$Gene == "psaC"), ]
par(mfrow = c(1, 2))
hist(gene_data$MDS1)
plot(gene_data$MDS1, gene_data$MDS2, main = "psaC MDS results")
#hist(gene_data$MDS2)
par(mfrow = c(1, 1))
for (ord in unique(df$Order)) {
  # A missing Order cannot be matched against anything.
  if (is.na(ord)) next
  # Literal comparison instead of grep(), which would interpret the Order name
  # as a regular expression and also match it as a substring.
  order_ids <- df$ID[which(df$Order == ord)]
  order_subset <- gene_data[gene_data$ID %in% order_ids, ]
  # hist() errors when there is no finite value to bin, so skip such Orders.
  if (!any(is.finite(order_subset$MDS1))) next
  hist(order_subset$MDS1, main = ord, xlim = c(-.1, .1))
}
# Boxplot of the psaC MDS1 coordinate grouped by taxonomic Order.
# grep() matches "psaC" as a regular expression against the Gene column.
gene_data <- clean_embeds[grep("psaC", clean_embeds$Gene),]
# Attach each sequence's Order by ID; merge() keeps only IDs present in both tables.
merged <- merge(gene_data, df[,c("ID","Order")], by="ID")

# las=2 rotates the axis labels; outline=FALSE suppresses the outlier points.
boxplot(MDS1 ~ Order, data=merged,
        main="psaC MDS1 by Order",
        las=2, outline=FALSE)
library(pheatmap)
# Stack the per-gene correlation vectors into a genes x embedding-dimension matrix.
mat <- do.call(rbind, all_cors)
rownames(mat) <- names(all_cors)
# Heatmap of the correlations on a 100-step blue-white-red colour ramp.
pheatmap(mat, color=colorRampPalette(c("blue","white","red"))(100))
# Genes x embedding-dimension matrix of phenotype correlations, one row per gene.
mat <- do.call(rbind, all_cors)
rownames(mat) <- names(all_cors)

# Gene-to-gene similarity: correlate the rows (each gene's correlation profile),
# using pairwise-complete observations.
gene_sim <- cor(
  t(mat),
  use = "pairwise.complete.obs"
)

# Clustered heatmap of the gene similarity matrix.
pheatmap(
  gene_sim,
  color = colorRampPalette(c("blue", "white", "red"))(100),
  main = "Similarity of gene embeddings wrt bio8"
)

# Distribution of all entries, then of the upper triangle only
# (each gene pair once, diagonal excluded).
hist(gene_sim)
offdiag <- gene_sim[upper.tri(gene_sim)]

hist(
  offdiag,
  breaks = 30,
  col = "skyblue",
  xlab = "Pairwise correlation",
  main = "Stability of embedding correlations across genes"
)
# Mark the mean pairwise similarity.
abline(v = mean(offdiag, na.rm = TRUE), lwd = 2, col = "red")
# NOTE(review): length() returns a single number (the element count of the
# column), so this histogram is drawn from one value. A per-sequence length such
# as nchar(...) / 3 was presumably intended -- confirm, and confirm that
# clean_embeds actually has a psaC_CDS column.
hist(length(clean_embeds$psaC_CDS)/3)
# Rebuild the genes x embedding-dimension correlation matrix.
mat <- do.call(rbind, all_cors)
rownames(mat) <- names(all_cors)

# per-embedding stats across genes
# (apply(mat, 2, ...) summarises each column, i.e. each embedding dimension)
embed_stats <- data.frame(
  dim = colnames(mat),
  mean_cor = apply(mat, 2, mean, na.rm=TRUE),
  mean_abs_cor = apply(mat, 2, function(x) mean(abs(x), na.rm=TRUE)),
  sd_cor = apply(mat, 2, sd, na.rm=TRUE)
)

# Mean vs spread of each dimension's correlation across genes, with the y = x line.
plot(embed_stats$mean_cor, embed_stats$sd_cor,
     xlab="mean of correlation across genes",
     ylab="sd of correlation across genes")
abline(a=0,b=1)
# pick embedding with strongest + most stable signal
# (largest mean absolute correlation first; ties broken by smaller sd)
best <- embed_stats[order(-embed_stats$mean_abs_cor, embed_stats$sd_cor), ][1, ]
best
library(stats)

# For every gene: correlate each embedding dimension with bio8, regress bio8 on
# the (up to) five most correlated dimensions, plot fitted vs observed, and
# store the correlations, the chosen dimensions and the model summary in
# results[[gene]].
pheno <- df[, c("ID", "pheno_wc2.1_2.5m_bio_8_p50")]  # same for every gene
results <- list()
for (gene in genes) {
  gene_data <- clean_embeds[Gene == gene]
  merged <- merge(gene_data, pheno, by = "ID")

  cors <- vapply(embed_cols, function(col) {
    cor(merged[[col]], merged$pheno_wc2.1_2.5m_bio_8_p50, use = "complete.obs")
  }, numeric(1))

  # sort() drops NA correlations. head() returns fewer than five names when
  # fewer are available, instead of padding the selection with NA.
  top_dims <- head(names(sort(abs(cors), decreasing = TRUE)), 5)
  if (length(top_dims) == 0) {
    warning(sprintf("Gene %s has no usable embedding correlations; skipping.", gene))
    next
  }
  formula_str <- paste("pheno_wc2.1_2.5m_bio_8_p50 ~", paste(top_dims, collapse = " + "))
  fit <- lm(as.formula(formula_str), data = merged)

  # lm() drops rows with missing values, so read the observed response from the
  # model frame; it is always aligned with the fitted values.
  plot(fit$fitted.values, model.response(model.frame(fit)),
       main = gene, xlab = "Fitted", ylab = "Observed")

  results[[gene]] <- list(
    cors = cors,
    top_dims = top_dims,
    model = summary(fit)
  )
}


results$psaC$top_dims
[1] "embedding_59"  "embedding_118" "embedding_459" "embedding_37"  "embedding_220"
results$psaC$model

Call:
lm(formula = as.formula(formula_str), data = merged)

Residuals:
    Min      1Q  Median      3Q     Max 
-25.935  -4.952   1.659   4.864  16.783 

Coefficients:
              Estimate Std. Error t value Pr(>|t|)    
(Intercept)     10.073      4.149   2.428 0.015204 *  
embedding_59   751.717    117.098   6.420 1.42e-10 ***
embedding_118 -155.406    153.138  -1.015 0.310218    
embedding_459 -633.788    188.642  -3.360 0.000783 ***
embedding_37  -252.371    131.748  -1.916 0.055447 .  
embedding_220  245.807    290.302   0.847 0.397165    
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 6.416 on 10761 degrees of freedom
Multiple R-squared:  0.02567,   Adjusted R-squared:  0.02521 
F-statistic: 56.69 on 5 and 10761 DF,  p-value: < 2.2e-16
# per-gene variable selection and building a combined design matrix
library(data.table)
library(stats)

# ensure data.tables
setDT(clean_embeds)
setDT(df)

# phenotype column name (adjust if you want a different pheno)
pheno_col <- "pheno_wc2.1_2.5m_bio_8_p50"

# number of top dims per gene
n_top <- 5

# ID + phenotype lookup; identical for every gene, so build it once
pheno_dt <- df[, .(ID, pheno = get(pheno_col))]

# container for per-gene selected data.tables (ID + renamed top embeds)
sel_list <- vector("list", length(genes))
names(sel_list) <- genes

for (gene in genes) {
  # subset gene
  gdt <- clean_embeds[Gene == gene, c("ID", embed_cols), with = FALSE]
  # merge with phenotype (inner join to ensure measurable correlation)
  gdt <- merge(gdt, pheno_dt, by = "ID", all.x = FALSE, all.y = FALSE)
  # if too few rows, skip
  if (nrow(gdt) < 5) {
    warning(sprintf("Gene %s has only %d rows; skipping.", gene, nrow(gdt)))
    next
  }
  # compute correlations (use complete.obs); vapply() keeps the result numeric
  cors <- vapply(embed_cols, function(col) {
    cor(gdt[[col]], gdt$pheno, use = "complete.obs")
  }, numeric(1))
  # rank the dims with a defined correlation by absolute value and take the
  # first n_top. head() copes with fewer than n_top (including zero), whereas
  # [1:k] with k == 0 would index c(1, 0).
  ranked <- names(sort(abs(cors[!is.na(cors)]), decreasing = TRUE))
  top_dims <- head(ranked, n_top)
  if (length(top_dims) == 0) {
    warning(sprintf("Gene %s has no usable embedding correlations; skipping.", gene))
    next
  }
  # select ID + these dims (the phenotype copy is left behind) and rename the
  # dims to gene__dim so columns stay unique after merging genes
  sel <- gdt[, c("ID", top_dims), with = FALSE]
  setnames(sel, old = top_dims, new = paste0(gene, "__", top_dims))
  sel_list[[gene]] <- sel
}

# remove genes we skipped
sel_list <- sel_list[!vapply(sel_list, is.null, logical(1))]

# merge all per-gene tables by ID, keeping only IDs present in ALL (intersection)
if (length(sel_list) == 0) stop("No genes with selected dims found.")
combined <- Reduce(function(a, b) merge(a, b, by = "ID", all = FALSE), sel_list)

# bring phenotype back
combined <- merge(combined, pheno_dt, by = "ID", all.x = TRUE, all.y = FALSE)

# quick checks
cat("Samples (IDs) in combined matrix:", nrow(combined), "\n")
Samples (IDs) in combined matrix: 5907 
cat("Number of predictors:", ncol(combined) - 2, "(ID and pheno excluded)\n")  # -2 for ID and pheno
Number of predictors: 305 (ID and pheno excluded)
if (inherits(lm_fit, "try-error")) {
  warning("lm failed (probably too many predictors). See glmnet example below.")
} else {
  print(summary(lm_fit))
}

Call:
lm(formula = lm_formula, data = combined)

Residuals:
     Min       1Q   Median       3Q      Max 
-24.1608  -2.9069   0.4772   3.1734  22.0632 

Coefficients:
                       Estimate Std. Error t value Pr(>|t|)    
(Intercept)             76.8254    55.6181   1.381 0.167241    
atpA__embedding_876   -208.8354   242.6664  -0.861 0.389503    
atpA__embedding_94    -575.7553   160.9851  -3.576 0.000351 ***
atpA__embedding_741   -179.2769   104.7317  -1.712 0.086994 .  
atpA__embedding_201    777.5276   123.6092   6.290 3.41e-10 ***
atpA__embedding_232   -262.7959   141.4287  -1.858 0.063200 .  
atpB__embedding_224    231.1105   145.2126   1.592 0.111546    
atpB__embedding_514    910.9106   238.5938   3.818 0.000136 ***
atpB__embedding_634    -93.0585   138.4092  -0.672 0.501393    
atpB__embedding_695    269.1380   145.7439   1.847 0.064851 .  
atpB__embedding_582    531.0567   213.9233   2.482 0.013077 *  
atpE__embedding_781   -107.8754   137.9126  -0.782 0.434130    
atpE__embedding_254     35.7627   102.8618   0.348 0.728096    
atpE__embedding_17      65.8292   111.7574   0.589 0.555860    
atpE__embedding_633    187.9588   114.2747   1.645 0.100068    
atpE__embedding_548    142.9291   110.8520   1.289 0.197323    
atpF__embedding_739     62.8667   128.4662   0.489 0.624603    
atpF__embedding_721   -130.7040    68.3097  -1.913 0.055748 .  
atpF__embedding_522    -46.2784   100.6542  -0.460 0.645694    
atpF__embedding_720    -95.5329    53.2440  -1.794 0.072828 .  
atpF__embedding_234    -96.1766    79.0348  -1.217 0.223698    
atpH__embedding_248   -392.6980   664.4051  -0.591 0.554509    
atpH__embedding_583   1072.0269   906.1453   1.183 0.236834    
atpH__embedding_46    -726.1298  1140.4567  -0.637 0.524346    
atpH__embedding_310  -1483.3559   789.0042  -1.880 0.060155 .  
atpH__embedding_256    449.7932   761.7268   0.590 0.554885    
atpI__embedding_282   -198.1053   148.1331  -1.337 0.181164    
atpI__embedding_349    -86.2666   135.7952  -0.635 0.525279    
atpI__embedding_823   -548.2127   190.5590  -2.877 0.004032 ** 
atpI__embedding_599    153.3643   173.1258   0.886 0.375734    
atpI__embedding_324   -320.3872   169.5294  -1.890 0.058828 .  
ccsA__embedding_736   -300.9649    87.9722  -3.421 0.000628 ***
ccsA__embedding_563   -175.8039    60.1767  -2.921 0.003498 ** 
ccsA__embedding_196   -142.1579    61.1807  -2.324 0.020184 *  
ccsA__embedding_710    165.9675    68.8158   2.412 0.015907 *  
ccsA__embedding_744     54.4999    77.4317   0.704 0.481558    
cemA__embedding_764     98.3458   103.4912   0.950 0.342010    
cemA__embedding_402    101.9166    62.3800   1.634 0.102356    
cemA__embedding_480     38.0450    51.3430   0.741 0.458727    
cemA__embedding_836    156.2334    84.3986   1.851 0.064202 .  
cemA__embedding_730   -178.8383    69.4146  -2.576 0.010010 *  
matK__embedding_867    186.8447    72.8111   2.566 0.010309 *  
matK__embedding_869     80.5346    99.9816   0.805 0.420569    
matK__embedding_655   -270.9279    88.1532  -3.073 0.002127 ** 
matK__embedding_272     61.9943    85.9940   0.721 0.470993    
matK__embedding_637    -96.6917   111.5608  -0.867 0.386134    
ndhA__embedding_204    139.1071    91.9982   1.512 0.130574    
ndhA__embedding_328    272.3231   113.4779   2.400 0.016437 *  
ndhA__embedding_924   -189.1086   117.8001  -1.605 0.108476    
ndhA__embedding_584    108.4594    94.1441   1.152 0.249347    
ndhA__embedding_637    339.3575    97.9665   3.464 0.000536 ***
ndhB__embedding_16      63.4961   251.6460   0.252 0.800801    
ndhB__embedding_829   -330.2832   273.7211  -1.207 0.227621    
ndhB__embedding_19    -286.1228   281.0232  -1.018 0.308652    
ndhB__embedding_701     13.1022   247.8648   0.053 0.957845    
ndhB__embedding_404    -36.2425   224.9070  -0.161 0.871986    
ndhC__embedding_172     -5.2171   112.4138  -0.046 0.962985    
ndhC__embedding_495   -381.2100   149.7819  -2.545 0.010951 *  
ndhC__embedding_619     80.1857   149.1111   0.538 0.590765    
ndhC__embedding_656    207.2433   158.5764   1.307 0.191301    
ndhC__embedding_939    100.8045   203.1951   0.496 0.619845    
ndhD__embedding_610     -5.3162    96.6444  -0.055 0.956134    
ndhD__embedding_154     65.7973   102.9415   0.639 0.522737    
ndhD__embedding_16     -22.1424    76.1572  -0.291 0.771257    
ndhD__embedding_126   -333.0086   107.4735  -3.099 0.001955 ** 
ndhD__embedding_493    129.2113    89.0514   1.451 0.146843    
ndhE__embedding_394   -235.0752   130.0469  -1.808 0.070719 .  
ndhE__embedding_291    215.9190    87.1772   2.477 0.013287 *  
ndhE__embedding_36     250.5818   107.1751   2.338 0.019419 *  
ndhE__embedding_327    194.3733   111.9194   1.737 0.082491 .  
ndhE__embedding_792    207.9030   117.2270   1.774 0.076199 .  
ndhG__embedding_936    -21.4389    69.0997  -0.310 0.756374    
ndhG__embedding_616   -235.8352   102.3305  -2.305 0.021223 *  
ndhG__embedding_689    150.3810   100.8533   1.491 0.135995    
ndhG__embedding_651    138.8690    58.8018   2.362 0.018228 *  
ndhG__embedding_600   -154.4796    92.1588  -1.676 0.093748 .  
ndhH__embedding_448     27.2937   244.6118   0.112 0.911161    
ndhH__embedding_958    154.8204   218.6563   0.708 0.478942    
ndhH__embedding_844    269.3165   131.0652   2.055 0.039942 *  
ndhH__embedding_282     34.9133   162.1551   0.215 0.829535    
ndhH__embedding_180   -154.7166   117.3752  -1.318 0.187512    
ndhI__embedding_356     -1.5720   105.6532  -0.015 0.988129    
ndhI__embedding_247    549.0817   124.7357   4.402 1.09e-05 ***
ndhI__embedding_729      5.9835   160.1021   0.037 0.970189    
ndhI__embedding_854    232.8629   113.8125   2.046 0.040801 *  
ndhI__embedding_802   -277.7352   107.0273  -2.595 0.009484 ** 
ndhJ__embedding_910     65.5178   120.5850   0.543 0.586923    
ndhJ__embedding_478   -370.7750   122.5513  -3.025 0.002494 ** 
ndhJ__embedding_719    822.2006   169.5346   4.850 1.27e-06 ***
ndhJ__embedding_220     60.0519   124.3602   0.483 0.629195    
ndhJ__embedding_728   -263.0135    91.7493  -2.867 0.004164 ** 
ndhK__embedding_939    -84.2404   108.0653  -0.780 0.435699    
ndhK__embedding_517     -9.8507   106.7471  -0.092 0.926478    
ndhK__embedding_506    126.2942   140.7308   0.897 0.369535    
ndhK__embedding_651     51.5243    82.6169   0.624 0.532880    
ndhK__embedding_323     91.6913    90.6985   1.011 0.312086    
petA__embedding_612    334.5739   153.5199   2.179 0.029347 *  
petA__embedding_729   -602.0623   184.7791  -3.258 0.001128 ** 
petA__embedding_897   -453.2711   148.9370  -3.043 0.002350 ** 
petA__embedding_324   -652.7140   156.1078  -4.181 2.94e-05 ***
petA__embedding_319    386.2192   139.2024   2.775 0.005547 ** 
petG__embedding_425     84.5845   262.7228   0.322 0.747500    
petG__embedding_374    -59.6645   315.6244  -0.189 0.850071    
petG__embedding_521    434.0284   318.9684   1.361 0.173655    
petG__embedding_915    615.3282   292.1598   2.106 0.035237 *  
petG__embedding_951    623.7942   217.1642   2.872 0.004088 ** 
petN__embedding_958    -58.0280   363.3937  -0.160 0.873136    
petN__embedding_489   -110.4462   313.8402  -0.352 0.724913    
petN__embedding_834     60.7427   443.6018   0.137 0.891090    
petN__embedding_36    -738.7642   526.0309  -1.404 0.160252    
petN__embedding_365     82.3972   408.3399   0.202 0.840091    
psaA__embedding_484   1108.8366   388.8539   2.852 0.004367 ** 
psaA__embedding_771    364.4961   282.5170   1.290 0.197044    
psaA__embedding_829    237.4691   404.4587   0.587 0.557141    
psaA__embedding_690   -112.0831   297.0699  -0.377 0.705969    
psaA__embedding_434    219.3752   308.5731   0.711 0.477154    
psaB__embedding_765   -281.7155   277.6143  -1.015 0.310258    
psaB__embedding_239   1715.7434   436.0704   3.935 8.44e-05 ***
psaB__embedding_44    -750.2019   342.0719  -2.193 0.028340 *  
psaB__embedding_366     21.2841   256.0890   0.083 0.933765    
psaB__embedding_707    156.7416   338.5362   0.463 0.643384    
psaC__embedding_59    -363.0662   304.2073  -1.193 0.232731    
psaC__embedding_118   -199.1120   360.1609  -0.553 0.580394    
psaC__embedding_459   -387.7371   319.8318  -1.212 0.225443    
psaC__embedding_37      22.7648   314.5797   0.072 0.942313    
psaC__embedding_220    500.4841   534.1604   0.937 0.348822    
psaJ__embedding_168     89.3835   129.4254   0.691 0.489834    
psaJ__embedding_5      348.1794   130.4995   2.668 0.007651 ** 
psaJ__embedding_931    163.9249   120.7750   1.357 0.174748    
psaJ__embedding_281   -194.2151   129.4487  -1.500 0.133587    
psaJ__embedding_189    182.5217   146.1124   1.249 0.211649    
psbA__embedding_514   1285.7125   533.1846   2.411 0.015924 *  
psbA__embedding_553   -541.1976   352.4451  -1.536 0.124705    
psbA__embedding_950    198.7766   342.1091   0.581 0.561242    
psbA__embedding_414    278.7585   374.8898   0.744 0.457165    
psbA__embedding_786    194.5895   375.2246   0.519 0.604064    
psbB__embedding_269    -91.2722   302.1339  -0.302 0.762593    
psbB__embedding_221   1010.9740   296.1797   3.413 0.000646 ***
psbB__embedding_854    226.2997   212.3829   1.066 0.286683    
psbB__embedding_655   1445.8331   273.5239   5.286 1.30e-07 ***
psbB__embedding_217    330.6320   247.2457   1.337 0.181192    
psbC__embedding_438   1593.7468   334.7349   4.761 1.97e-06 ***
psbC__embedding_30     828.2348   233.5209   3.547 0.000393 ***
psbC__embedding_394  -1446.6585   408.0907  -3.545 0.000396 ***
psbC__embedding_836    640.6248   331.8060   1.931 0.053568 .  
psbC__embedding_251   -606.9502   431.3952  -1.407 0.159499    
psbD__embedding_351   -497.3078   298.4608  -1.666 0.095721 .  
psbD__embedding_814   1073.5785   468.3372   2.292 0.021924 *  
psbD__embedding_574   -361.4385   132.2990  -2.732 0.006315 ** 
psbD__embedding_702   -375.5442   474.4349  -0.792 0.428650    
psbD__embedding_352    261.6981   366.9131   0.713 0.475725    
psbE__embedding_261    -60.4732   195.3587  -0.310 0.756915    
psbE__embedding_480    131.1661    90.3575   1.452 0.146659    
psbE__embedding_143    114.7043   274.4804   0.418 0.676039    
psbE__embedding_724    -64.5953   260.7966  -0.248 0.804388    
psbE__embedding_152   -367.9184   257.4361  -1.429 0.153013    
psbF__embedding_163    753.9860   266.0071   2.834 0.004607 ** 
psbF__embedding_595    317.8768   254.1353   1.251 0.211053    
psbF__embedding_406   -990.0347   396.2723  -2.498 0.012505 *  
psbF__embedding_848     77.8062   266.0174   0.292 0.769926    
psbF__embedding_335    375.3744   195.6757   1.918 0.055117 .  
psbH__embedding_712    -12.6540    92.8964  -0.136 0.891656    
psbH__embedding_478   -179.3157   114.4757  -1.566 0.117310    
psbH__embedding_656    -19.7671   103.5579  -0.191 0.848627    
psbH__embedding_305     77.4610    81.8522   0.946 0.344010    
psbH__embedding_481    279.7592    88.2166   3.171 0.001526 ** 
psbI__embedding_490  -1729.5498   414.8573  -4.169 3.11e-05 ***
psbI__embedding_728    -29.2831   319.8111  -0.092 0.927048    
psbI__embedding_1      309.0379   448.6852   0.689 0.491001    
psbI__embedding_890     40.1530   370.3839   0.108 0.913675    
psbI__embedding_151  -1164.5657   352.0678  -3.308 0.000946 ***
psbJ__embedding_72     274.8862   237.2479   1.159 0.246650    
psbJ__embedding_286      8.8962   196.0862   0.045 0.963815    
psbJ__embedding_5      366.8672   225.3730   1.628 0.103619    
psbJ__embedding_24      41.2185    27.4781   1.500 0.133658    
psbJ__embedding_924     53.8107   166.4295   0.323 0.746462    
psbK__embedding_53     190.3357    91.6326   2.077 0.037832 *  
psbK__embedding_723    227.7159    77.4277   2.941 0.003285 ** 
psbK__embedding_255   -127.0001    91.6073  -1.386 0.165694    
psbK__embedding_892    -80.0394    72.4595  -1.105 0.269377    
psbK__embedding_222    -38.7123    77.1417  -0.502 0.615804    
psbM__embedding_53    -113.9512   145.9467  -0.781 0.434969    
psbM__embedding_570    642.2320   196.3048   3.272 0.001076 ** 
psbM__embedding_475    127.5162   118.5945   1.075 0.282319    
psbM__embedding_646    -44.4199   196.6175  -0.226 0.821272    
psbM__embedding_959   -243.3920   124.5491  -1.954 0.050729 .  
psbN__embedding_527    440.9233   211.7174   2.083 0.037333 *  
psbN__embedding_815   -536.4711   219.9086  -2.440 0.014738 *  
psbN__embedding_156     21.7603   255.8526   0.085 0.932225    
psbN__embedding_490    314.1987   233.1926   1.347 0.177913    
psbN__embedding_505     17.6603   295.4896   0.060 0.952344    
psbT__embedding_604     30.2452   171.4804   0.176 0.860004    
psbT__embedding_711    521.1176   150.9538   3.452 0.000560 ***
psbT__embedding_941    252.3240   114.5114   2.203 0.027601 *  
psbT__embedding_106    384.7161   147.7706   2.603 0.009253 ** 
psbT__embedding_803     43.8725   103.3712   0.424 0.671278    
psbZ__embedding_338    203.0167   186.6933   1.087 0.276892    
psbZ__embedding_227    -68.5603   196.9585  -0.348 0.727782    
psbZ__embedding_861     29.6209   145.3293   0.204 0.838502    
psbZ__embedding_90    -144.6589   128.2703  -1.128 0.259467    
 [ reached getOption("max.print") -- omitted 106 rows ]
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 4.896 on 5601 degrees of freedom
Multiple R-squared:  0.4499,    Adjusted R-squared:  0.4199 
F-statistic: 15.02 on 305 and 5601 DF,  p-value: < 2.2e-16

# per-gene variable selection and building a combined design matrix
library(data.table)
library(stats)

setDT(clean_embeds)
setDT(df)

pheno_col <- "pheno_wc2.1_2.5m_bio_8_p50"
n_top <- 1

# ID + phenotype lookup; identical for every gene, so build it once
pheno_dt <- df[, .(ID, pheno = get(pheno_col))]

# one slot per gene; skipped genes stay NULL
sel_list <- vector("list", length(genes))
names(sel_list) <- genes

for (gene in genes) {
  gdt <- clean_embeds[Gene == gene, c("ID", embed_cols), with = FALSE]
  # inner join: only sequences with a phenotype record are used
  gdt <- merge(gdt, pheno_dt, by = "ID", all.x = FALSE, all.y = FALSE)

  if (nrow(gdt) < 5) {
    warning(sprintf("Gene %s has only %d rows; skipping.", gene, nrow(gdt)))
    next
  }

  cors <- vapply(embed_cols, function(col) {
    cor(gdt[[col]], gdt$pheno, use = "complete.obs")
  }, numeric(1))
  # rank dims with a defined correlation by absolute value. head() handles
  # fewer than n_top candidates (including none), unlike [1:k] with k == 0.
  ranked <- names(sort(abs(cors[!is.na(cors)]), decreasing = TRUE))
  top_dims <- head(ranked, n_top)
  if (length(top_dims) == 0) {
    warning(sprintf("Gene %s has no usable embedding correlations; skipping.", gene))
    next
  }

  # keep ID + selected dims, renamed to gene__dim so they stay unique when merged
  sel <- gdt[, c("ID", top_dims), with = FALSE]
  setnames(sel, old = top_dims, new = paste0(gene, "__", top_dims))
  sel_list[[gene]] <- sel
}

Quitting from lines 256-307 [unnamed-chunk-16] (makePairsPostMDSFilter.Rmd)
# drop the genes that were skipped during selection (their slot is still NULL);
# vapply() stays logical even when sel_list is empty, where !sapply(...) errors
sel_list <- sel_list[!vapply(sel_list, is.null, logical(1))]
if (length(sel_list) == 0) stop("No genes with selected dims found.")

# merge across all IDs (full outer join instead of intersection)
combined <- Reduce(function(a, b) merge(a, b, by = "ID", all = TRUE), sel_list)

# bring phenotype back
combined <- merge(combined, df[, .(ID, pheno = get(pheno_col))], by = "ID", all.x = TRUE, all.y = FALSE)

# median imputation for missing predictors (exclude ID + pheno)
pred_cols <- setdiff(names(combined), c("ID", "pheno"))
for (col in pred_cols) {
  missing_rows <- which(is.na(combined[[col]]))
  if (length(missing_rows) > 0) {
    # set() updates by reference and addresses the column by name, so the loop
    # variable cannot be confused with a column of the table
    set(combined, i = missing_rows, j = col,
        value = median(combined[[col]], na.rm = TRUE))
  }
}

# quick checks
cat("Samples (IDs) in combined matrix:", nrow(combined), "\n")
Samples (IDs) in combined matrix: 10857 
cat("Number of predictors:", length(pred_cols), "\n")
Number of predictors: 61 
# fit a plain linear model (may be unstable if predictors >> samples)
# remove ID column for model
# The formula is "pheno ~ <every column of combined except ID and pheno>".
lm_formula <- as.formula(paste("pheno ~", paste(setdiff(colnames(combined), c("ID", "pheno")), collapse = " + ")))
# try() captures a fitting error as a "try-error" object instead of aborting.
lm_fit <- try(lm(lm_formula, data = combined), silent = TRUE)

# Print the model summary on success; otherwise only warn.
if (inherits(lm_fit, "try-error")) {
  warning("lm failed (probably too many predictors). See glmnet example below.")
} else {
  print(summary(lm_fit))
}

Call:
lm(formula = lm_formula, data = combined)

Residuals:
     Min       1Q   Median       3Q      Max 
-28.0942  -3.2983   0.8184   3.8462  19.7686 

Coefficients:
                      Estimate Std. Error t value Pr(>|t|)    
(Intercept)             76.205      5.541  13.752  < 2e-16 ***
atpA__embedding_876    393.909     85.681   4.597 4.33e-06 ***
atpB__embedding_224     20.269     57.698   0.351 0.725374    
atpE__embedding_781    -69.591     48.947  -1.422 0.155127    
atpF__embedding_739   -103.899     44.859  -2.316 0.020570 *  
atpH__embedding_248    214.038    136.790   1.565 0.117679    
atpI__embedding_282    -13.242     47.525  -0.279 0.780532    
ccsA__embedding_736   -115.425     31.021  -3.721 0.000200 ***
cemA__embedding_764     -5.409     40.665  -0.133 0.894194    
matK__embedding_867    -33.408     30.345  -1.101 0.270935    
ndhA__embedding_204      7.551     38.344   0.197 0.843889    
ndhB__embedding_16      52.587     65.182   0.807 0.419815    
ndhC__embedding_172    -67.894     33.199  -2.045 0.040869 *  
ndhD__embedding_610   -129.337     33.226  -3.893 9.98e-05 ***
ndhE__embedding_394   -161.721     49.409  -3.273 0.001067 ** 
ndhG__embedding_936    189.680     26.602   7.130 1.07e-12 ***
ndhH__embedding_448   -327.157    101.589  -3.220 0.001284 ** 
ndhI__embedding_356   -118.316     39.630  -2.986 0.002837 ** 
ndhJ__embedding_910    -17.921     54.810  -0.327 0.743706    
ndhK__embedding_939     49.281     38.795   1.270 0.204003    
petA__embedding_612   -192.007     47.767  -4.020 5.87e-05 ***
petG__embedding_425     76.336     30.410   2.510 0.012079 *  
petN__embedding_958   -364.799    115.487  -3.159 0.001589 ** 
psaA__embedding_484    617.128    138.603   4.452 8.57e-06 ***
psaB__embedding_765     44.646    121.998   0.366 0.714402    
psaC__embedding_59    -128.172     99.698  -1.286 0.198608    
psaJ__embedding_168    391.925     51.369   7.630 2.55e-14 ***
psbA__embedding_514   1464.388    200.833   7.292 3.28e-13 ***
psbB__embedding_269    723.484     96.089   7.529 5.51e-14 ***
psbC__embedding_438    767.096    114.227   6.716 1.97e-11 ***
psbD__embedding_351   -532.484    105.238  -5.060 4.27e-07 ***
psbE__embedding_261    259.766     85.849   3.026 0.002485 ** 
psbF__embedding_163    891.369    115.667   7.706 1.41e-14 ***
psbH__embedding_712   -145.498     37.374  -3.893 9.96e-05 ***
psbI__embedding_490   -994.834    143.682  -6.924 4.64e-12 ***
psbJ__embedding_72     988.822     67.731  14.599  < 2e-16 ***
psbK__embedding_53     127.885     40.554   3.153 0.001618 ** 
psbM__embedding_53     212.570     51.348   4.140 3.50e-05 ***
psbN__embedding_527    207.500     45.044   4.607 4.14e-06 ***
psbT__embedding_604    -87.724     42.380  -2.070 0.038485 *  
psbZ__embedding_338    336.252     58.052   5.792 7.14e-09 ***
rbcL__embedding_425  -1219.722     78.300 -15.578  < 2e-16 ***
rpl14__embedding_327    19.500     28.170   0.692 0.488828    
rpl16__embedding_355   -38.872     30.023  -1.295 0.195434    
rpl2__embedding_422     15.997     34.288   0.467 0.640832    
rpl20__embedding_422   -32.124     42.109  -0.763 0.445555    
rpl23__embedding_434   -12.479     31.976  -0.390 0.696355    
rpl33__embedding_446    54.198     27.397   1.978 0.047925 *  
rpl36__embedding_661    37.785     34.438   1.097 0.272593    
rpoA__embedding_201     38.182     28.417   1.344 0.179095    
rpoB__embedding_774   -327.040     55.123  -5.933 3.07e-09 ***
rpoC1__embedding_238  -141.284     66.400  -2.128 0.033380 *  
rps11__embedding_254   -12.219     32.930  -0.371 0.710594    
rps14__embedding_437   157.310     28.094   5.599 2.20e-08 ***
rps18__embedding_255    81.100     24.592   3.298 0.000977 ***
rps19__embedding_469    70.034     19.765   3.543 0.000397 ***
rps3__embedding_349    -62.348     27.691  -2.252 0.024371 *  
rps4__embedding_363      9.651     38.608   0.250 0.802616    
rps7__embedding_838   -183.027     78.300  -2.338 0.019431 *  
rps8__embedding_78      79.731     30.892   2.581 0.009865 ** 
ycf3__embedding_29    -110.264     61.399  -1.796 0.072544 .  
ycf4__embedding_373      1.939     57.604   0.034 0.973143    
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 5.434 on 10795 degrees of freedom
Multiple R-squared:  0.3022,    Adjusted R-squared:  0.2983 
F-statistic: 76.65 on 61 and 10795 DF,  p-value: < 2.2e-16
# Fitted vs observed phenotype for the combined model. The title reports the
# n_top and sample count actually in use rather than hard-coded values.
plot(lm_fit$fitted.values, combined$pheno,
     main = paste0("pheno ~ top", n_top, " embs per gene, n=", nrow(combined)))

library(ggplot2)

set.seed(123)  # reproducibility

# predictors
pred_cols <- setdiff(names(combined), c("ID", "pheno"))

# root-mean-square error between two numeric vectors
rmse <- function(a, b) sqrt(mean((a - b)^2))

## ---- 1. Random 10% holdout ----
idx_test <- sample(seq_len(nrow(combined)), size = ceiling(0.1 * nrow(combined)))
train1 <- combined[-idx_test]
test1  <- combined[idx_test]

fit1 <- lm(pheno ~ ., data = train1[, c("pheno", pred_cols), with = FALSE])
pred1 <- predict(fit1, newdata = test1)

# plot
ggplot(data.frame(obs = test1$pheno, pred = pred1), aes(x = pred, y = obs)) +
  geom_point(alpha = 0.5) +
  geom_abline(color = "red", linetype = "dashed") +
  labs(title = "Random 10% Holdout",
       x = "Predicted phenotype",
       y = "Observed phenotype") +
  theme_minimal()



## ---- 2. Poales holdout ----
# IDs whose Taxonomy string contains "Poales"
poales_ids <- df[grepl("Poales", Taxonomy), ID]

train2 <- combined[!ID %in% poales_ids]
test2  <- combined[ID %in% poales_ids]

fit2 <- lm(pheno ~ ., data = train2[, c("pheno", pred_cols), with = FALSE])
pred2 <- predict(fit2, newdata = test2)

# plot
ggplot(data.frame(obs = test2$pheno, pred = pred2), aes(x = pred, y = obs)) +
  geom_point(alpha = 0.5, color = "darkgreen") +
  geom_abline(color = "red", linetype = "dashed") +
  labs(title = "Poales Holdout",
       x = "Predicted phenotype",
       y = "Observed phenotype") +
  theme_minimal()

# cor() defaults to Pearson; request Spearman explicitly so the value matches
# the "spearman=" label drawn on the plot.
holdout_rho  <- cor(pred2, test2$pheno, method = "spearman")
holdout_rmse <- rmse(pred2, test2$pheno)

plot(pred2, test2$pheno, main = "bio8 ~ top1 embs/gene, holdout Poales",
     xlab = "Predicted",
     ylab = "Observed")
abline(a = 0, b = 1, col = "coral")
text(25, 5,
     paste0("spearman=", round(holdout_rho, 3)), col = "red")
text(25, 2,
     paste0("RMSE=", round(holdout_rmse, 3)), col = "red")

# root-mean-square error between two numeric vectors (defined once, not per iteration)
rmse <- function(a, b) sqrt(mean((a - b)^2))

# Leave-one-Order-out evaluation: for each taxonomic Order, train on every other
# sample, predict the held-out Order, and plot predicted vs observed.
for (ord in unique(df$Order)) {
  # A missing Order cannot be matched against Taxonomy.
  if (is.na(ord)) next

  # fixed = TRUE matches the Order name literally inside the Taxonomy string
  # instead of interpreting it as a regular expression.
  holdout_ids <- df[grepl(ord, Taxonomy, fixed = TRUE), ID]

  train2 <- combined[!ID %in% holdout_ids]
  test2  <- combined[ID %in% holdout_ids]

  # A rank correlation needs at least two held-out samples, and the model needs
  # training rows; skip Orders where either is missing.
  if (nrow(test2) < 2 || nrow(train2) <= 10) next

  fit2 <- lm(pheno ~ ., data = train2[, c("pheno", pred_cols), with = FALSE])
  pred2 <- predict(fit2, newdata = test2)

  # ggplot objects are only drawn when printed; inside a for loop that has to be
  # done explicitly.
  print(
    ggplot(data.frame(obs = test2$pheno, pred = pred2), aes(x = pred, y = obs)) +
      geom_point(alpha = 0.5, color = "darkgreen") +
      geom_abline(color = "red", linetype = "dashed") +
      labs(title = paste0(ord, " Holdout"),
           x = "Predicted phenotype",
           y = "Observed phenotype") +
      theme_minimal()
  )

  # cor() defaults to Pearson; request Spearman so the value matches its label.
  holdout_rho  <- cor(pred2, test2$pheno, method = "spearman")
  holdout_rmse <- rmse(pred2, test2$pheno)

  # Title reports the n_top actually in use rather than a hard-coded "top5".
  plot(pred2, test2$pheno,
       main = paste0("bio8 ~ top", n_top, " embs/gene, holdout ", ord),
       xlab = "Predicted",
       ylab = "Observed")
  abline(a = 0, b = 1, col = "coral")
  text(quantile(pred2, 0.75), quantile(test2$pheno, 0.25),
       paste0("spearman=", round(holdout_rho, 3)), col = "red")
  text(quantile(pred2, 0.75), quantile(test2$pheno, 0.25) - 2,
       paste0("RMSE=", round(holdout_rmse, 3)), col = "red")
}

library(data.table)
library(ggplot2)

setDT(clean_embeds)
setDT(df)

pheno_col <- "pheno_wc2.1_2.5m_bio_8_p50"
genes <- unique(clean_embeds$Gene)
max_top <- 100L  # largest number of embedding dims per gene to evaluate

# root-mean-square error between two numeric vectors
rmse <- function(a, b) sqrt(mean((a - b)^2))

# Loop-invariant inputs: the phenotype lookup and the Poaceae hold-out IDs.
pheno_dt <- df[, .(ID, pheno = get(pheno_col))]
poaceae_ids <- df[grepl("Poaceae", Taxonomy), ID]

# Rank the embedding dims of every gene ONCE. The ranking does not depend on
# n_top, so recomputing every correlation for each n_top is wasted work.
# Only the max_top best columns per gene are kept, to limit memory.
# NOTE(review): the ranking uses all samples, including those later held out,
# so the hold-out scores below are optimistic -- confirm this is acceptable.
gene_tabs <- list()
for (gene in genes) {
  gdt <- clean_embeds[Gene == gene, c("ID", embed_cols), with = FALSE]
  gdt <- merge(gdt, pheno_dt, by = "ID", all.x = FALSE, all.y = FALSE)

  if (nrow(gdt) < 5) next

  cors <- vapply(embed_cols, function(col) {
    cor(gdt[[col]], gdt$pheno, use = "complete.obs")
  }, numeric(1))
  ranked <- names(sort(abs(cors[!is.na(cors)]), decreasing = TRUE))
  # No dim with a defined correlation: nothing to select for this gene.
  if (length(ranked) == 0) next

  ranked <- head(ranked, max_top)
  gene_tabs[[gene]] <- list(
    ranked = ranked,
    data = gdt[, c("ID", ranked), with = FALSE]
  )
}

# One data.table per (n_top, split); bound together once after the loop
# instead of growing a table with rbind() on every iteration.
result_rows <- list()

set.seed(123)

for (n_top in seq_len(max_top)) {
  if (length(gene_tabs) == 0) next

  # ID + the n_top best dims of each gene, renamed to gene__dim
  sel_list <- lapply(names(gene_tabs), function(gene) {
    tab <- gene_tabs[[gene]]
    top_dims <- head(tab$ranked, n_top)
    sel <- tab$data[, c("ID", top_dims), with = FALSE]
    setnames(sel, old = top_dims, new = paste0(gene, "__", top_dims))
    sel
  })
  names(sel_list) <- names(gene_tabs)

  combined <- Reduce(function(a, b) merge(a, b, by = "ID", all = TRUE), sel_list)
  combined <- merge(combined, pheno_dt, by = "ID", all.x = TRUE, all.y = FALSE)

  # median impute
  pred_cols <- setdiff(names(combined), c("ID", "pheno"))
  for (col in pred_cols) {
    med <- median(combined[[col]], na.rm = TRUE)
    combined[is.na(get(col)), (col) := med]
  }

  ## ---- 1. Random 10% holdout ----
  idx_test <- sample(seq_len(nrow(combined)), size = ceiling(0.1 * nrow(combined)))
  train1 <- combined[-idx_test]
  test1  <- combined[idx_test]

  fit1 <- lm(pheno ~ ., data = train1[, c("pheno", pred_cols), with = FALSE])
  pred1 <- predict(fit1, newdata = test1)

  result_rows[[length(result_rows) + 1L]] <- data.table(
    n_top = n_top,
    split = "random10",
    spearman = cor(pred1, test1$pheno, method = "spearman"),
    rmse = rmse(pred1, test1$pheno)
  )

  ## ---- 2. Poaceae holdout ----
  train2 <- combined[!ID %in% poaceae_ids]
  test2  <- combined[ID %in% poaceae_ids]

  if (nrow(test2) > 0 && nrow(train2) > 10) {
    fit2 <- lm(pheno ~ ., data = train2[, c("pheno", pred_cols), with = FALSE])
    pred2 <- predict(fit2, newdata = test2)

    result_rows[[length(result_rows) + 1L]] <- data.table(
      n_top = n_top,
      split = "poaceae",
      spearman = cor(pred2, test2$pheno, method = "spearman"),
      rmse = rmse(pred2, test2$pheno)
    )
  }
}

# store results: start from an empty, typed table so `results` has the same
# columns even when nothing was evaluated
results <- data.table(n_top = integer(),
                      split = character(),
                      spearman = numeric(),
                      rmse = numeric())
if (length(result_rows) > 0) {
  results <- rbindlist(result_rows)
}

Quitting from lines 420-514 [unnamed-chunk-21] (makePairsPostMDSFilter.Rmd)
---
title: "Find pairs for ortholog contrast"
output:
  pdf_document: default
  html_notebook: default
editor_options:
  chunk_output_type: inline
---
```{r setup, include=FALSE}
# Load knitr
library(knitr)
library(data.table)
library(arrow)
# Set the working directory for all chunks
opts_knit$set(root.dir = "/local/workdir/hdd29/chloroplast_genome_evaluation")
```

We can read in the embedding file, which has the outlier protein sequences detected by manual inspection of MDS plots

```{r}
# Load the embedding table (includes the ManualOutlier flag and MDS columns).
embeds <- readRDS("data/tmp/embeds_with_mds.rds")
# Count flagged vs unflagged rows, then report the table dimensions.
table(embeds$ManualOutlier)
dim(embeds)
# Names of all columns whose name contains "embedding" (the embedding dimensions).
embed_cols <- grep("embedding", colnames(embeds), value=TRUE)

# Keep only the rows that were not flagged as manual outliers.
clean_embeds <- embeds[ManualOutlier == FALSE]
# Preview the first five columns of the filtered table.
head(clean_embeds[,1:5])
#write_parquet(clean_embeds, "data/clean_embeds.parquet")

#clean_seq <- clean_embeds[,c("ID", "Taxonomy", "Gene", "AA_seq")]
#write_parquet(clean_seq, "data/clean_AA_seqs.parquet")
#head(clean_seq)
```
And now we can read in the data with phenotypic values 

```{r}
# Read the processed table and convert it to a data.table in one step.
df <- setDT(read_parquet("data/processed_data.parquet"))

#pheno <- clean_data[,c("ID", "pheno_Topt_site_p50")]
  
#write_parquet(pheno, "data/pheno_topt_clean.parquet")

```

Now, in the meantime, I wish to analyze the correlation between the embeddings and my phenotypes. 
```{r}
# For every gene, correlate each embedding dimension with the bio8 phenotype
# and keep the per-dimension correlations in all_cors[[gene]].
genes <- unique(clean_embeds$Gene)
all_cors <- list()

# The phenotype lookup is identical for every gene, so build it once.
pheno <- df[, c("ID", "pheno_wc2.1_2.5m_bio_8_p50")]

for (gene in genes) {
  # Exact match on the gene name. grep() treats the name as a regex and
  # matches substrings, so a name that is a prefix of another gene's name
  # would silently pull in that other gene's rows too.
  gene_data <- clean_embeds[Gene == gene]
  gene_data <- merge(gene_data, pheno, by = "ID")

  # vapply guarantees one numeric per embedding column.
  cors <- vapply(embed_cols, function(col) {
    cor(gene_data[[col]], gene_data$pheno_wc2.1_2.5m_bio_8_p50,
        use = "complete.obs")
  }, numeric(1))

  hist(cors, main = gene)
  all_cors[[gene]] <- cors
}

```

```{r}
# Two panels in one row: histograms of the per-dimension correlations stored
# in all_cors for psbN and for rbcL.
par(mfrow=c(1,2))
hist(all_cors$psbN, main="Cor of psbN embeds with bio8")
hist(all_cors$rbcL, main="rbcL")
```
```{r}
# psaC rows only (exact match, consistent with the `Gene == gene` subsetting
# used by the modelling chunks in this notebook).
gene_data <- clean_embeds[Gene == "psaC"]

# Overall MDS1 distribution and the MDS1/MDS2 scatter, side by side.
par(mfrow = c(1, 2))
hist(gene_data$MDS1)
plot(gene_data$MDS1, gene_data$MDS2, main = "psaC MDS results")
#hist(gene_data$MDS2)

# One MDS1 histogram per Order, on a common x range.
par(mfrow = c(1, 1))
for (ord in unique(df$Order)) {
  # An NA order cannot be matched against anything; skip it.
  if (is.na(ord)) next

  # Exact comparison: grep() would also match orders whose name merely
  # contains `ord` as a substring.
  order_ids <- df[Order == ord, "ID"]
  order_subset <- gene_data[gene_data$ID %in% order_ids$ID, ]

  # hist() fails on an empty / all-NA vector, so skip orders without psaC data.
  if (nrow(order_subset) == 0 || all(is.na(order_subset$MDS1))) next

  hist(order_subset$MDS1, main = ord, xlim = c(-.1, .1))
}
```
```{r}
# Rows whose Gene matches "psaC", joined to each sample's Order.
is_psaC <- grepl("psaC", clean_embeds$Gene)
gene_data <- clean_embeds[which(is_psaC), ]
merged <- merge(gene_data, df[, .(ID, Order)], by = "ID")

# MDS1 distribution per Order; axis labels vertical, outliers not drawn.
boxplot(
  MDS1 ~ Order,
  data = merged,
  main = "psaC MDS1 by Order",
  las = 2,
  outline = FALSE
)
```
```{r}
library(pheatmap)

# Stack the per-gene correlation vectors: one row per gene, one column per
# embedding dimension.
mat <- do.call(rbind, all_cors)
rownames(mat) <- names(all_cors)

# Diverging blue-white-red palette with 100 steps.
blue_white_red <- colorRampPalette(c("blue", "white", "red"))
pheatmap(mat, color = blue_white_red(100))
```
```{r}
# Genes x embedding-dimension matrix of phenotype correlations.
mat <- do.call(rbind, all_cors)
rownames(mat) <- names(all_cors)

# Gene-by-gene similarity: correlate the rows of `mat` with each other
# (cor() works on columns, hence the transpose).
dims_by_gene <- t(mat)
gene_sim <- cor(dims_by_gene, use = "pairwise.complete.obs")

# Clustered heatmap of the similarity matrix.
sim_palette <- colorRampPalette(c("blue", "white", "red"))(100)
pheatmap(
  gene_sim,
  main = "Similarity of gene embeddings wrt bio8",
  color = sim_palette
)

# Distribution of all entries of the similarity matrix.
hist(gene_sim)
```

```{r}
# Each gene pair once: the strict upper triangle of the similarity matrix.
upper_mask <- upper.tri(gene_sim)
offdiag <- gene_sim[upper_mask]

hist(
  offdiag,
  breaks = 30,
  main = "Stability of embedding correlations across genes",
  xlab = "Pairwise correlation",
  col = "skyblue"
)

# Mark the mean pairwise similarity.
offdiag_mean <- mean(offdiag, na.rm = TRUE)
abline(v = offdiag_mean, col = "red", lwd = 2)
```

```{r}
# Histogram of psaC lengths in codons (characters of the CDS string / 3).
# length() gives the number of elements in the column (one number), so the
# original call drew a histogram of a single value; nchar() gives the length
# of each sequence.
# NOTE(review): assumes psaC_CDS is a character (or factor) column of
# clean_embeds -- confirm the column lives in this table and not in `df`.
psaC_CDS <- clean_embeds[["psaC_CDS"]]
if (is.null(psaC_CDS)) {
  warning("clean_embeds has no psaC_CDS column; skipping length histogram.")
} else {
  psaC_len <- nchar(as.character(psaC_CDS)) / 3
  hist(psaC_len)
}
```
```{r}
# Genes x embedding-dimension matrix of phenotype correlations.
mat <- do.call(rbind, all_cors)
rownames(mat) <- names(all_cors)

# Apply a summary function down every column (embedding dimension) of `mat`.
per_dim <- function(f) apply(mat, 2, f)

# Per-dimension summaries across genes: mean, mean absolute value and sd of
# the correlation, all ignoring NAs.
embed_stats <- data.frame(
  dim = colnames(mat),
  mean_cor = per_dim(function(x) mean(x, na.rm = TRUE)),
  mean_abs_cor = per_dim(function(x) mean(abs(x), na.rm = TRUE)),
  sd_cor = per_dim(function(x) sd(x, na.rm = TRUE))
)

plot(
  embed_stats$mean_cor, embed_stats$sd_cor,
  xlab = "mean of correlation across genes",
  ylab = "sd of correlation across genes"
)
abline(a = 0, b = 1)

# Rank dimensions by largest mean |cor|, breaking ties by smallest sd, and
# show the top one.
ranking <- order(-embed_stats$mean_abs_cor, embed_stats$sd_cor)
best <- embed_stats[ranking, ][1, ]
best
```

```{r}
library(stats)

# For each gene: correlate every embedding dimension with bio8, fit a linear
# model on the (up to) 5 most-correlated dimensions, plot fitted vs observed,
# and store the correlations, the chosen dimensions and the model summary.
pheno <- df[, c("ID", "pheno_wc2.1_2.5m_bio_8_p50")]

results <- list()
for (gene in genes) {
  gene_data <- clean_embeds[Gene == gene]
  merged <- merge(gene_data, pheno, by = "ID")

  cors <- vapply(embed_cols, function(col) {
    cor(merged[[col]], merged$pheno_wc2.1_2.5m_bio_8_p50, use = "complete.obs")
  }, numeric(1))

  # sort() drops NA correlations; head() returns fewer than 5 names when fewer
  # are usable, whereas [1:5] would pad with NA and break the formula.
  top_dims <- head(names(sort(abs(cors), decreasing = TRUE)), 5)
  if (length(top_dims) == 0) {
    warning(sprintf("Gene %s has no usable correlations; skipping.", gene))
    next
  }

  formula_str <- paste("pheno_wc2.1_2.5m_bio_8_p50 ~",
                       paste(top_dims, collapse = " + "))
  # na.exclude makes fitted() return one value per row of `merged` (NA where a
  # row was dropped), so the plot below cannot hit a length mismatch.
  fit <- lm(as.formula(formula_str), data = merged, na.action = na.exclude)

  plot(fitted(fit), merged$pheno_wc2.1_2.5m_bio_8_p50,
       main = gene, xlab = "Fitted", ylab = "Observed")

  results[[gene]] <- list(
    cors = cors,
    top_dims = top_dims,
    model = summary(fit)
  )
}

results$psaC$top_dims
results$psaC$model

```


```{r}
# Per-gene variable selection and construction of a combined design matrix:
# for each gene pick the n_top embedding dimensions most correlated with the
# phenotype, then inner-join the per-gene selections on ID.
library(data.table)
library(stats)

# ensure data.tables
setDT(clean_embeds)
setDT(df)

# phenotype column name (adjust if you want a different pheno)
pheno_col <- "pheno_wc2.1_2.5m_bio_8_p50"

# number of top dims per gene
n_top <- 5

# container for per-gene selected data.tables (ID + renamed top embeds);
# entries for skipped genes stay NULL and are removed after the loop
sel_list <- vector("list", length(genes))
names(sel_list) <- genes

for (gene in genes) {
  # subset gene
  gdt <- clean_embeds[Gene == gene, c("ID", embed_cols), with = FALSE]
  # inner join with the phenotype so correlations are computed on matched IDs
  gdt <- merge(gdt, df[, .(ID, pheno = get(pheno_col))], by = "ID", all.x = FALSE, all.y = FALSE)
  # too few rows for a meaningful correlation: skip
  if (nrow(gdt) < 5) {
    warning(sprintf("Gene %s has only %d rows; skipping.", gene, nrow(gdt)))
    next
  }

  # correlation of every embedding dimension with the phenotype
  cors <- vapply(embed_cols, function(col) {
    cor(gdt[[col]], gdt$pheno, use = "complete.obs")
  }, numeric(1))

  # Rank by |cor|; sort() drops NA correlations. head() copes with fewer than
  # n_top usable dims, including none at all (where 1:k would be c(1, 0) and
  # produce an NA column name).
  top_dims <- head(names(sort(abs(cors), decreasing = TRUE)), n_top)
  if (length(top_dims) == 0) {
    warning(sprintf("Gene %s has no usable correlations; skipping.", gene))
    next
  }

  # keep ID + the selected dims, renamed to <gene>__<dim>; the phenotype
  # column is not selected, so no copy of it is carried along
  sel <- gdt[, c("ID", top_dims), with = FALSE]
  setnames(sel, old = top_dims, new = paste0(gene, "__", top_dims))
  sel_list[[gene]] <- sel
}

# remove genes we skipped
sel_list <- sel_list[!vapply(sel_list, is.null, logical(1))]

# merge all per-gene tables by ID, keeping only IDs present in ALL (intersection)
if (length(sel_list) == 0) stop("No genes with selected dims found.")
combined <- Reduce(function(a, b) merge(a, b, by = "ID", all = FALSE), sel_list)

# bring phenotype back
combined <- merge(combined, df[, .(ID, pheno = get(pheno_col))], by = "ID", all.x = TRUE, all.y = FALSE)

# quick checks
cat("Samples (IDs) in combined matrix:", nrow(combined), "\n")
cat("Number of predictors:", ncol(combined) - 2, "(ID and pheno excluded)\n")  # -2 for ID and pheno
```


```{r}
# Ordinary least squares on every predictor column of `combined`
# (everything except ID and pheno). The fit may be unstable when there are
# many predictors relative to samples.
predictors <- setdiff(colnames(combined), c("ID", "pheno"))
rhs <- paste(predictors, collapse = " + ")
lm_formula <- as.formula(paste("pheno ~", rhs))

# try() keeps the chunk running if lm() throws; lm_fit is then a "try-error".
lm_fit <- try(lm(lm_formula, data = combined), silent = TRUE)

if (!inherits(lm_fit, "try-error")) {
  print(summary(lm_fit))
} else {
  warning("lm failed (probably too many predictors). See glmnet example below.")
}
```

```{r}
# Fitted vs observed for the combined model.
# - Only plot when the fit succeeded (lm_fit may be a "try-error" object).
# - Take the observed values from the model frame rather than from `combined`:
#   lm() drops rows with a missing phenotype, so combined$pheno can be longer
#   than the fitted values and plot() would fail on the length mismatch.
if (inherits(lm_fit, "lm")) {
  plot(fitted(lm_fit), lm_fit$model$pheno,
       xlab = "Fitted values", ylab = "Observed pheno")
} else {
  warning("lm_fit is not a fitted model; skipping fitted-vs-observed plot.")
}
```

```{r}
# Per-gene variable selection and construction of a combined design matrix:
# for each gene pick the n_top embedding dimension(s) most correlated with the
# phenotype, full-outer-join the per-gene selections on ID, and median-impute
# the gaps the outer join leaves behind.
library(data.table)
library(stats)

setDT(clean_embeds)
setDT(df)

pheno_col <- "pheno_wc2.1_2.5m_bio_8_p50"
n_top <- 1

# entries for skipped genes stay NULL and are removed after the loop
sel_list <- vector("list", length(genes))
names(sel_list) <- genes

for (gene in genes) {
  gdt <- clean_embeds[Gene == gene, c("ID", embed_cols), with = FALSE]
  gdt <- merge(gdt, df[, .(ID, pheno = get(pheno_col))], by = "ID", all.x = FALSE, all.y = FALSE)

  if (nrow(gdt) < 5) {
    warning(sprintf("Gene %s has only %d rows; skipping.", gene, nrow(gdt)))
    next
  }

  cors <- vapply(embed_cols, function(col) {
    cor(gdt[[col]], gdt$pheno, use = "complete.obs")
  }, numeric(1))

  # Rank by |cor|; sort() drops NA correlations. head() copes with fewer than
  # n_top usable dims, including none at all (where 1:k would be c(1, 0) and
  # produce an NA column name).
  top_dims <- head(names(sort(abs(cors), decreasing = TRUE)), n_top)
  if (length(top_dims) == 0) {
    warning(sprintf("Gene %s has no usable correlations; skipping.", gene))
    next
  }

  sel <- gdt[, c("ID", top_dims), with = FALSE]
  setnames(sel, old = top_dims, new = paste0(gene, "__", top_dims))
  sel_list[[gene]] <- sel
}

sel_list <- sel_list[!vapply(sel_list, is.null, logical(1))]

# merge across all IDs (full outer join instead of intersection)
combined <- Reduce(function(a, b) merge(a, b, by = "ID", all = TRUE), sel_list)

# bring phenotype back
combined <- merge(combined, df[, .(ID, pheno = get(pheno_col))], by = "ID", all.x = TRUE, all.y = FALSE)

# median imputation for missing predictors (exclude ID + pheno)
pred_cols <- setdiff(names(combined), c("ID", "pheno"))
for (col in pred_cols) {
  med <- median(combined[[col]], na.rm = TRUE)
  combined[is.na(get(col)), (col) := med]
}

# quick checks
cat("Samples (IDs) in combined matrix:", nrow(combined), "\n")
cat("Number of predictors:", length(pred_cols), "\n")

```

```{r}
# Fit a plain linear model on every predictor column of `combined`
# (everything except ID and pheno); may be unstable if predictors >> samples.
lm_formula <- as.formula(paste("pheno ~", paste(setdiff(colnames(combined), c("ID", "pheno")), collapse = " + ")))
lm_fit <- try(lm(lm_formula, data = combined), silent = TRUE)

if (inherits(lm_fit, "try-error")) {
  warning("lm failed (probably too many predictors). See glmnet example below.")
} else {
  print(summary(lm_fit))

  # Plot only when the fit succeeded. Observed values come from the model
  # frame because lm() drops rows with a missing phenotype, which would make
  # combined$pheno longer than the fitted values.
  # The title is built from n_top and the number of rows actually used, so it
  # cannot go stale when the selection settings change.
  plot(fitted(lm_fit), lm_fit$model$pheno,
       main = sprintf("pheno ~ top%d embs per gene, n=%d", n_top, nobs(lm_fit)),
       xlab = "Fitted values", ylab = "Observed pheno")
}
```



```{r}
library(ggplot2)

set.seed(123)  # fixed seed so the random split is repeatable

# Predictor columns: everything except the ID and the response.
pred_cols <- setdiff(names(combined), c("ID", "pheno"))

# Observed-vs-predicted scatter with a dashed red identity line.
# `point_colour = NULL` leaves the point colour at the ggplot default.
holdout_scatter <- function(obs, pred, title, point_colour = NULL) {
  points <- if (is.null(point_colour)) {
    geom_point(alpha = 0.5)
  } else {
    geom_point(alpha = 0.5, color = point_colour)
  }
  ggplot(data.frame(obs = obs, pred = pred), aes(x = pred, y = obs)) +
    points +
    geom_abline(color = "red", linetype = "dashed") +
    labs(title = title,
         x = "Predicted phenotype",
         y = "Observed phenotype") +
    theme_minimal()
}

## ---- 1. Random 10% holdout ----
n_rows <- nrow(combined)
idx_test <- sample.int(n_rows, size = ceiling(0.1 * n_rows))
train1 <- combined[-idx_test]
test1  <- combined[idx_test]

fit1 <- lm(pheno ~ ., data = train1[, .SD, .SDcols = c("pheno", pred_cols)])
pred1 <- predict(fit1, newdata = test1)

holdout_scatter(test1$pheno, pred1, "Random 10% Holdout")


## ---- 2. Poales holdout ----
# IDs whose Taxonomy string contains "Poales" form the test set.
poaceae_ids <- df[grepl("Poales", Taxonomy), ID]

in_holdout <- combined$ID %in% poaceae_ids
train2 <- combined[!in_holdout]
test2  <- combined[in_holdout]

fit2 <- lm(pheno ~ ., data = train2[, .SD, .SDcols = c("pheno", pred_cols)])
pred2 <- predict(fit2, newdata = test2)

holdout_scatter(test2$pheno, pred2, "Poales Holdout", point_colour = "darkgreen")

```

```{r}
# Root-mean-square error between two equal-length numeric vectors.
rmse     <- function(a,b) sqrt(mean((a - b)^2))

# Score only the held-out rows that have both a prediction and an observed
# phenotype; otherwise a single NA turns both metrics into NA.
scored <- complete.cases(pred2, test2$pheno)

plot(pred2,test2$pheno, main="bio8 ~ top1 embs/gene, holdout Poales",
      xlab="Predicted",
      ylab="Observed")
abline(a=0,b=1, col="coral")
# The label says "spearman", so request the rank correlation explicitly
# (cor() defaults to Pearson).
text(25, 5,
     paste0("spearman=",
            round(cor(pred2[scored], test2$pheno[scored], method = "spearman"), 3)),
     col="red")
text(25,2,
     paste0("RMSE=",round(rmse(pred2[scored], test2$pheno[scored]), 3)),col="red")
```

```{r}
# Leave-one-Order-out evaluation: for every Order, train on all other samples,
# predict the held-out Order, and plot predicted vs observed with Spearman
# correlation and RMSE.

# Root-mean-square error (defined once, not on every iteration).
rmse <- function(a, b) sqrt(mean((a - b)^2))

for (ord in unique(df$Order)) {
  # An NA order cannot be matched; skip it.
  if (is.na(ord)) next

  # IDs whose Taxonomy string contains this Order's name (literal match).
  holdout_ids <- df[grepl(ord, Taxonomy, fixed = TRUE), ID]

  train2 <- combined[!ID %in% holdout_ids]
  # Only rows with an observed phenotype can be scored; NA phenotypes would
  # also make quantile() below fail.
  test2  <- combined[ID %in% holdout_ids & !is.na(pheno)]

  # Same guard as the Poaceae holdout earlier in this notebook: need something
  # to test on and enough rows to train on.
  if (nrow(test2) == 0 || nrow(train2) <= 10) next

  fit2 <- lm(pheno ~ ., data = train2[, c("pheno", pred_cols), with = FALSE])
  pred2 <- predict(fit2, newdata = test2)

  # ggplot objects are only drawn when printed; inside a for loop there is no
  # auto-printing, so print explicitly.
  print(
    ggplot(data.frame(obs = test2$pheno, pred = pred2), aes(x = pred, y = obs)) +
      geom_point(alpha = 0.5, color = "darkgreen") +
      geom_abline(color = "red", linetype = "dashed") +
      labs(title = paste0(ord, " Holdout"),
           x = "Predicted phenotype",
           y = "Observed phenotype") +
      theme_minimal()
  )

  plot(pred2, test2$pheno, main = paste0("bio8 ~ top1 embs/gene, holdout ", ord),
       xlab = "Predicted",
       ylab = "Observed")
  abline(a = 0, b = 1, col = "coral")

  # Annotation anchor: upper quartile of predictions, lower quartile of
  # observations.
  label_x <- quantile(pred2, 0.75)
  label_y <- quantile(test2$pheno, 0.25)
  # The label says "spearman", so request the rank correlation explicitly
  # (cor() defaults to Pearson).
  text(label_x, label_y,
       paste0("spearman=", round(cor(pred2, test2$pheno, method = "spearman"), 3)),
       col = "red")
  text(label_x, label_y - 2,
       paste0("RMSE=", round(rmse(pred2, test2$pheno), 3)), col = "red")
}
```


```{r}
# Variable selection: for each gene, pick the n_top embedding dimensions with
# the largest absolute mean within-Order correlation with the phenotype, then
# full-outer-join the per-gene selections on ID and median-impute the gaps.
setDT(clean_embeds)

pheno_col <- "pheno_wc2.1_2.5m_bio_8_p50"
orders <- unique(df$Order)
n_top <- 5
sel_list <- vector("list", length(genes))
names(sel_list) <- genes

# Orders that can actually be matched (NA cannot).
order_names <- as.character(orders[!is.na(orders)])

# Pearson correlation on complete pairs; NA (instead of an error or a trivial
# +/-1) when fewer than 3 complete pairs are available.
cor_complete <- function(x, y) {
  ok <- !is.na(x) & !is.na(y)
  if (sum(ok) < 3) return(NA_real_)
  cor(x[ok], y[ok])
}

for (gene in genes) {
  gdt <- clean_embeds[Gene == gene, c("ID", embed_cols), with = FALSE]
  gdt <- merge(gdt, df[, .(ID, Order, pheno = get(pheno_col))], by = "ID")
  if (nrow(gdt) < 5) next

  # Row indices per Order, computed once per gene instead of once per
  # (embedding column, Order) pair; Orders with fewer than 3 rows contribute
  # nothing to the mean, so drop them here.
  order_rows <- lapply(order_names, function(ord) which(gdt$Order == ord))
  order_rows <- order_rows[lengths(order_rows) >= 3]
  pheno_vals <- gdt$pheno

  # mean within-Order correlation for every embedding dimension
  cors_by_order <- vapply(embed_cols, function(col) {
    x <- gdt[[col]]
    vals <- vapply(order_rows, function(rows) {
      cor_complete(x[rows], pheno_vals[rows])
    }, numeric(1))
    mean(vals, na.rm = TRUE)
  }, numeric(1))

  # Rank by |mean cor|; sort() drops NA/NaN. head() copes with fewer than
  # n_top usable dims, including none (where 1:k would be c(1, 0) and produce
  # an NA column name).
  top_dims <- head(names(sort(abs(cors_by_order), decreasing = TRUE)), n_top)
  if (length(top_dims) == 0) {
    warning(sprintf("Gene %s has no usable within-Order correlations; skipping.", gene))
    next
  }

  hist(cors_by_order, main = gene)

  sel <- gdt[, c("ID", top_dims), with = FALSE]
  setnames(sel, old = top_dims, new = paste0(gene, "__", top_dims))
  sel_list[[gene]] <- sel
}

# drop skipped genes
sel_list <- sel_list[!vapply(sel_list, is.null, logical(1))]

# merge across all IDs (full outer join)
combined <- Reduce(function(a, b) merge(a, b, by = "ID", all = TRUE), sel_list)

# bring phenotype back
combined <- merge(combined, df[, .(ID, pheno = get(pheno_col))], by = "ID", all.x = TRUE)

# median impute missing predictors
pred_cols <- setdiff(names(combined), c("ID", "pheno"))
for (col in pred_cols) {
  med <- median(combined[[col]], na.rm = TRUE)
  combined[is.na(get(col)), (col) := med]
}

cat("Samples in combined matrix:", nrow(combined), "\n")
cat("Number of predictors:", length(pred_cols), "\n")



```

